{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference settings shared by the cells below.\n",
    "config = dict(\n",
    "    # Hugging Face checkpoint name the tokenizer is loaded from.\n",
    "    backbone='microsoft/deberta-v3-base',\n",
    "    # Passed as cache_dir when the tokenizer is loaded.\n",
    "    model_path='./cache',\n",
    "    max_length=512,\n",
    "    # Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "    device='cuda' if torch.cuda.is_available() else 'cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'(MaxRetryError(\"HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /microsoft/deberta-v3-base/resolve/main/tokenizer_config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000002B89411B640>, 'Connection to huggingface.co timed out. (connect timeout=10)'))\"), '(Request ID: cfe61fd4-1174-4fad-948b-b72a73c96259)')' thrown while requesting HEAD https://huggingface.co/microsoft/deberta-v3-base/resolve/main/tokenizer_config.json\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "d:\\ProgramData\\Anaconda3\\lib\\site-packages\\transformers\\convert_slow_tokenizer.py:454: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Load the fine-tuned model. torch.load unpickles a whole module object here, so only open\n",
    "# checkpoints from a trusted source. The module defining the model class must be importable\n",
    "# (presumably Frozenbert.base_model -- TODO confirm against the training code).\n",
    "model_path = './checkpoints/FrozenBert-epochs3-val_mcrmse0.4411.pth'\n",
    "# map_location remaps the stored tensors onto config['device'], so a checkpoint that was\n",
    "# saved on a GPU still loads on a CPU-only machine.\n",
    "model = torch.load(model_path, map_location=config['device'])\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(config['backbone'], cache_dir=config['model_path'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Predictions on the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final Test Predictions:\n",
      "\n",
      "essay0---cohesion: 3.50, syntax: 3.50, vocabulary: 3.50, phraseology: 3.50, grammar: 3.50, conventions: 3.50\n",
      "essay1---cohesion: 3.00, syntax: 2.50, vocabulary: 3.00, phraseology: 2.50, grammar: 2.00, conventions: 2.50\n",
      "essay2---cohesion: 3.00, syntax: 3.00, vocabulary: 3.00, phraseology: 3.00, grammar: 2.50, conventions: 2.50\n"
     ]
    }
   ],
   "source": [
    "from utils.classes import EssayDataset\n",
    "from utils.func import to_scores\n",
    "\n",
    "# data frame\n",
    "# NOTE(review): `df` (train.csv) is not referenced again in this cell; it is kept so the\n",
    "# notebook-level name stays available.\n",
    "df = pd.read_csv('./input/feedback-prize-english-language-learning/train.csv')\n",
    "test_df = pd.read_csv(\n",
    "    './input/feedback-prize-english-language-learning/test.csv')\n",
    "test_ds = EssayDataset(test_df, config, tokenizer=tokenizer, is_test=True)\n",
    "# shuffle=False: the report below labels predictions by position (essay0, essay1, ...),\n",
    "# so batches must arrive in the same order as the rows of test.csv.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)\n",
    "\n",
    "model.eval()\n",
    "preds = []\n",
    "\n",
    "# Inference only: no_grad stops autograd from recording a graph for every forward pass.\n",
    "with torch.no_grad():\n",
    "    for inputs in test_loader:\n",
    "        # Move every tensor of the batch onto config['device'].\n",
    "        inputs = {k: v.to(device=config['device']) for k, v in inputs.items()}\n",
    "\n",
    "        outputs = model(inputs)\n",
    "        preds.append(outputs.detach().cpu())\n",
    "\n",
    "preds = torch.concat(preds)\n",
    "\n",
    "preds = to_scores(preds)\n",
    "\n",
    "print(\"Final Test Predictions:\\n\")\n",
    "# One line per essay; counting dataset rows keeps this loop independent of batch_size.\n",
    "for i in range(len(test_ds)):\n",
    "    print('essay%d---cohesion: %.2f, syntax: %.2f, vocabulary: %.2f, phraseology: %.2f, grammar: %.2f, conventions: %.2f'\n",
    "          % (i, preds[i][0], preds[i][1], preds[i][2], preds[i][3], preds[i][4], preds[i][5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([     1,   2651,   9805,    704,   1603,    272,    307,   1632,    269,\n",
       "           1496,    361,    400,    301,    295,    350,    619,    335,    301,\n",
       "            281,    489,    653,    491,    261,    476,   1757,    264,    291,\n",
       "           1548,    260,    879,    355,    504,    272,    278,    269,    489,\n",
       "            493,    264,    489,    282,    653,    491,    354,    264,    282,\n",
       "          20627,    263,    324,    942,    288,    305,    264,   9603,    385,\n",
       "            277,    290,   3968,    260,    489,    653,    491,   4573,    274,\n",
       "            551,    266,   1100,    265,  14285,    261,   1632,    682,    274,\n",
       "            409,    264,   6525,    264,    333,    402,    310,    354,    339,\n",
       "            274,    464,    261,   3606,    295,    327,    799,    491,    641,\n",
       "            262,    384,    260,    279,    362,    919,    272,    489,    653,\n",
       "            269,    493,    354,    653,    942,    269,    272,    278,   1360,\n",
       "            274,    266,   1100,    265,    274,  25330,    491,    260,    335,\n",
       "            274,   1723,    266,   2204,    274,    551,   1473,    275,    955,\n",
       "            267,   1266,    272,    274,  13621,    491,    263,   1756,    446,\n",
       "            260,    270,    738,    261,   9163,    274,   3665,    292,    563,\n",
       "            263,    825,    277,    264,  19777,    261,   3606,    551,   1473,\n",
       "            274,    327,    391,    272,    274,   1103,    283,    657,    283,\n",
       "            274,    295,    264,    850,    290,  16518,    263,   1826,  19777,\n",
       "            289,    402,    264,    262,    609,    447,    264,    374,    260,\n",
       "          13562,    261, 111178,    510,    290,   1593,    682,    274,    409,\n",
       "            264,   6525,    264,    333,    310,    260,    335,    274,    850,\n",
       "            491,    274,    372,    428,    264,    955,    261,    309,   2931,\n",
       "            273,    464,    291,    354,    339,    995,    295,    584,    464,\n",
       "            275,    312,    432,    302,    309,    260, 108139,    510,    491,\n",
       "            387,    365,    274,    409,    264,    333,    402,    310,    260,\n",
       "            434,    738,    261,    335,    274,   1723,    405,    283,    563,\n",
       "            393,    734,    825,    321,    265,    262,    669,    263,    685,\n",
       "            277,    274,    451,    261,   1632,   1360,    274,    266,   1100,\n",
       "            265,    274,    411,    299,   2142,    604,    263,    274,    591,\n",
       "            262,   5890,    264,    685,    277,    274,    451,    263,    327,\n",
       "            264,   2059,    274,    451,    432,    361,    274,    409,    278,\n",
       "            264,    282,    260,  28318,    278,   1360,    274,    262,   1177,\n",
       "            264,    424,   1026,    267,    290,   1206,    260,   2872,    278,\n",
       "            527,    274,    262,   1062,    264,    799,    353,    263,   1772,\n",
       "            576,    641,    262,    384,    272,    274,    372,    282,    526,\n",
       "            264,    380,    267,    290,    451,    432,    260,    293,  25330,\n",
       "            491,    274,    295,    799,    361,    264,    408,    274,    934,\n",
       "            264,   5779,    353,    263,    493,    479,    267,    274,    432,\n",
       "            260,    434,    738,    366,   2319,   5078,    563,    274,    296,\n",
       "            413,    305,    272,    274,   1859,    263,    274,    295,   1514,\n",
       "            278,    278,    347,   3322,    265,    290,    432,    264,    408,\n",
       "            274,   5376,    267,    432,    260,    367,    295,    327,   1514,\n",
       "            278,    264,    262,    688,    274,   5012,    289,    402,    264,\n",
       "            350,    266,   1026,   1082,    260,   1589,    347,    372,   8749,\n",
       "            275,    351,    263,    504,    272,  42347,    269,   2136,    335,\n",
       "            833,    264,   5779,    491,    401,    278,    296,    298,   1138,\n",
       "           2148,    277,    274,    366,   6356,    291,   6834,    264,    553,\n",
       "            269,    272,    278,   4573,    274,    638,    266,    493,    604,\n",
       "            267,    290,    406,    264,    406,    432,    260,    267,   4533,\n",
       "            489,    653,    269,    493,    354,    653,    942,    401,    278,\n",
       "           1530,    274,   1912,    310,    267,    274,    432,    261,  51451,\n",
       "            274,    551,   5912,    267,    432,    261,    263,   1057,    272,\n",
       "            274,    295,    489,   6525,    264,    333,    310,    267,    274,\n",
       "            432,      2,      0,      0,      0,      0,      0,      0,      0,\n",
       "              0,      0,      0,      0,      0,      0,      0,      0,      0,\n",
       "              0,      0,      0,      0,      0,      0,      0,      0]),\n",
       " 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0])}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one encoded sample from the test dataset (index 2)\n",
    "test_ds[2]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Scoring arbitrary text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.classes import EssayDataset\n",
    "from utils.func import to_scores\n",
    "# test result after training\n",
    "\n",
    "\n",
    "def test(test_loader):\n",
    "    \"\"\"Run the notebook-level `model` over every batch of `test_loader`.\n",
    "\n",
    "    Args:\n",
    "        test_loader: iterable yielding dict-like batches whose values are tensors.\n",
    "\n",
    "    Returns:\n",
    "        The per-batch model outputs, detached, moved to the CPU and joined with\n",
    "        torch.concat in iteration order.\n",
    "    \"\"\"\n",
    "    preds = []\n",
    "\n",
    "    # Inference only: no_grad stops autograd from recording a graph for every forward pass.\n",
    "    with torch.no_grad():\n",
    "        for inputs in test_loader:\n",
    "            # Move every tensor of the batch onto config['device'].\n",
    "            inputs = {k: v.to(device=config['device']) for k, v in inputs.items()}\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            preds.append(outputs.detach().cpu())\n",
    "\n",
    "    return torch.concat(preds)\n",
    "\n",
    "\n",
    "def test_essay(essay):\n",
    "    \"\"\"Score one essay with the notebook-level `model` and print the result.\n",
    "\n",
    "    The text is wrapped in a one-row DataFrame (columns `text_id`, `full_text`) so it goes\n",
    "    through the same EssayDataset / DataLoader path as the test set.\n",
    "\n",
    "    Args:\n",
    "        essay: the essay text to score.\n",
    "\n",
    "    Returns:\n",
    "        None. The six scores are printed with one decimal place each.\n",
    "    \"\"\"\n",
    "    test_df = pd.DataFrame(\n",
    "        [[0, essay]], columns=['text_id', 'full_text'], dtype=object)\n",
    "\n",
    "    # Reuse the notebook-level tokenizer (same backbone) instead of calling from_pretrained\n",
    "    # on every invocation, which re-contacts the Hugging Face hub each time.\n",
    "    test_ds = EssayDataset(test_df, config, tokenizer=tokenizer, is_test=True)\n",
    "\n",
    "    # A single row has nothing to shuffle; keep the loader deterministic.\n",
    "    test_loader = torch.utils.data.DataLoader(test_ds,\n",
    "                                              batch_size=1,\n",
    "                                              shuffle=False,\n",
    "                                              num_workers=0,\n",
    "                                              pin_memory=True\n",
    "                                              )\n",
    "\n",
    "    model.eval()\n",
    "    preds = test(test_loader=test_loader)\n",
    "\n",
    "    preds = to_scores(preds)\n",
    "\n",
    "    print('cohesion: %.1f, syntax: %.1f, vocabulary: %.1f, phraseology: %.1f, grammar: %.1f, conventions: %.1f'\n",
    "          % (preds[0][0], preds[0][1], preds[0][2], preds[0][3], preds[0][4], preds[0][5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'(MaxRetryError(\"HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /microsoft/deberta-v3-base/resolve/main/tokenizer_config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000002B89418B640>, 'Connection to huggingface.co timed out. (connect timeout=10)'))\"), '(Request ID: fc790a5f-223b-4383-b2b2-38fd72bec7f2)')' thrown while requesting HEAD https://huggingface.co/microsoft/deberta-v3-base/resolve/main/tokenizer_config.json\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "d:\\ProgramData\\Anaconda3\\lib\\site-packages\\transformers\\convert_slow_tokenizer.py:454: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohesion: 0.5, syntax: 0.5, vocabulary: 1.0, phraseology: 0.5, grammar: 1.5, conventions: 0.5\n"
     ]
    }
   ],
   "source": [
    "# Type the essay to be scored here\n",
    "essay = \"yo aaa h ddwdjd ly good.\"\n",
    "test_essay(essay)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Validation check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'(MaxRetryError(\"HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /microsoft/deberta-v3-base/resolve/main/tokenizer_config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000002B81EB10E80>, 'Connection to huggingface.co timed out. (connect timeout=10)'))\"), '(Request ID: f7162836-e8d6-479f-af73-121b43dd1d05)')' thrown while requesting HEAD https://huggingface.co/microsoft/deberta-v3-base/resolve/main/tokenizer_config.json\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "d:\\ProgramData\\Anaconda3\\lib\\site-packages\\transformers\\convert_slow_tokenizer.py:454: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this reads train.csv rather than a separate validation file -- confirm\n",
    "# that is the intended data for this check.\n",
    "val_df = pd.read_csv(\n",
    "    './input/feedback-prize-english-language-learning/train.csv')\n",
    "\n",
    "# Same cache_dir as the other tokenizer loads in this notebook, so the files stored under\n",
    "# config['model_path'] are reused.\n",
    "tokenizer = AutoTokenizer.from_pretrained(config['backbone'], cache_dir=config['model_path'])\n",
    "\n",
    "val_ds = EssayDataset(val_df, config, tokenizer=tokenizer)\n",
    "\n",
    "# shuffle=False keeps the batches in file order.\n",
    "val_loader = torch.utils.data.DataLoader(val_ds,\n",
    "                                         batch_size=1,\n",
    "                                         shuffle=False,\n",
    "                                         num_workers=0,\n",
    "                                         pin_memory=True\n",
    "                                         )\n",
    "# val_loader.dataset.df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ca38149857347329fee27905e709707",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3911 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohesion: , syntax: , vocabulary: , phraseology: , grammar: , conventions: \n",
      "targets:  [[3.5, 4.0, 4.0, 3.5, 3.5, 4.0]]\n",
      "outputs:  [[3.189666509628296, 3.175337791442871, 3.2942020893096924, 3.1032185554504395, 3.2570724487304688, 3.431267023086548]]\n",
      "\n",
      "targets:  [[3.5, 4.0, 3.5, 3.5, 4.0, 4.0]]\n",
      "outputs:  [[3.3855955600738525, 3.5182673931121826, 3.52046275138855, 3.3622946739196777, 3.489351511001587, 3.573880910873413]]\n",
      "\n",
      "targets:  [[4.0, 3.5, 3.0, 4.0, 3.5, 4.0]]\n",
      "outputs:  [[3.5180983543395996, 3.6222243309020996, 3.8053174018859863, 3.699758768081665, 3.693772077560425, 3.711016893386841]]\n",
      "\n",
      "targets:  [[3.0, 4.0, 3.0, 4.0, 3.5, 3.0]]\n",
      "outputs:  [[3.5345354080200195, 3.6554787158966064, 3.741503953933716, 3.804649591445923, 3.7657315731048584, 3.47724986076355]]\n",
      "\n",
      "targets:  [[4.0, 4.0, 4.0, 4.0, 4.5, 3.5]]\n",
      "outputs:  [[4.054250717163086, 4.097082614898682, 4.196408748626709, 4.205399990081787, 4.119227409362793, 4.012894153594971]]\n",
      "\n",
      "targets:  [[4.0, 4.5, 5.0, 4.0, 5.0, 4.5]]\n",
      "outputs:  [[4.263969421386719, 4.341069221496582, 4.489424705505371, 4.354664325714111, 4.333625793457031, 4.471264362335205]]\n",
      "\n",
      "targets:  [[4.5, 4.0, 4.0, 4.0, 4.0, 3.5]]\n",
      "outputs:  [[4.1276631355285645, 4.01705265045166, 4.240116596221924, 4.106070041656494, 3.982333183288574, 3.8614211082458496]]\n",
      "\n",
      "avg_targets:  [[3.127077579498291, 3.0282535552978516, 3.235745429992676, 3.116849899291992, 3.032855987548828, 3.0810534954071045]]\n",
      "avg_outputs:  [[3.2234506607055664, 3.1721324920654297, 3.3899431228637695, 3.190525531768799, 3.1054060459136963, 3.2027618885040283]]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def valid(model, val_loader):\n",
    "    \"\"\"Print a few target/prediction pairs, then the mean target and mean output per score.\n",
    "\n",
    "    Args:\n",
    "        model: torch module called with the dict of input tensors; returns the predictions.\n",
    "        val_loader: yields (inputs, targets) pairs; targets['labels'] holds the labels.\n",
    "            The sample filter reads only row 0 of each batch and the means divide by the\n",
    "            number of batches, so batch_size=1 is assumed.\n",
    "\n",
    "    Returns:\n",
    "        None; everything is reported through print.\n",
    "    \"\"\"\n",
    "    # Put the model in evaluation mode here instead of relying on an earlier cell.\n",
    "    model.eval()\n",
    "\n",
    "    shown = 0\n",
    "    progress = tqdm(val_loader, total=len(val_loader))\n",
    "\n",
    "    avg_targets = torch.zeros((1, 6)).to(device=config['device'])\n",
    "    avg_outputs = torch.zeros((1, 6)).to(device=config['device'])\n",
    "    print('cohesion: , syntax: , vocabulary: , phraseology: , grammar: , conventions: ')\n",
    "\n",
    "    # no_grad: otherwise every `outputs` that requires grad carries its autograd graph, and\n",
    "    # summing them into avg_outputs keeps all of those graphs alive until the loop ends.\n",
    "    with torch.no_grad():\n",
    "        for (inputs, targets) in progress:\n",
    "\n",
    "            inputs = {k: v.to(device=config['device']) for k, v in inputs.items()}\n",
    "\n",
    "            targets = targets['labels'].to(device=config['device'])\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            # Show at most 7 samples whose first or second target score equals 4.\n",
    "            if shown < 7:\n",
    "                target_rows = targets.tolist()\n",
    "                if target_rows[0][1] == 4 or target_rows[0][0] == 4:\n",
    "                    print(\"targets: \", target_rows)\n",
    "                    print(\"outputs: \", outputs.tolist())\n",
    "                    print()\n",
    "                    shown += 1\n",
    "\n",
    "            avg_targets = torch.add(avg_targets, targets)\n",
    "            avg_outputs = torch.add(avg_outputs, outputs)\n",
    "\n",
    "    n_batches = len(val_loader)\n",
    "    avg_targets = avg_targets / n_batches\n",
    "    avg_outputs = avg_outputs / n_batches\n",
    "\n",
    "    print(\"avg_targets: \", avg_targets.tolist())\n",
    "    print(\"avg_outputs: \", avg_outputs.tolist())\n",
    "\n",
    "\n",
    "valid(model, val_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
